# Start from an empty workspace so results cannot depend on objects left over
# from earlier work. NOTE(review): this also wipes the caller's global
# environment if the script is source()d interactively.
rm(list = ls())

# library() stops immediately when a package is missing; require() would only
# return FALSE and let the script fail later with a less obvious error.
library(plotROC)
library(pROC)
library(dplyr)
library(purrr)
library(tibble)
library(magrittr)

## Load data and RF results ---------------------------------------------------
# load() restores the saved objects into the global environment. The code
# below expects `data_full` and, later in the script, `RF_base` to be
# provided by these two files.
load("Data/data_full.RData")
load("Results/RF_base.RData")

# Keep the outcome and the model covariates, then drop every row that has a
# missing value in any kept column. The column list is built inside local()
# so no helper object is left behind in the workspace.
# NOTE(review): `llow_GDPPC_log` (double "l") is spelled exactly as the column
# is referenced here; confirm it matches the name in data_full.
data <- local({
  model_columns <- c(
    "trigercrisis", "cooperative", "conflictual",
    "cooperative_cum", "conflictual_cum",
    "ratio_GDP", "delta_ratio_GDP",
    "trade_depen", "delta_trade_depen",
    "idealpointdistance", "delta_idealpointdistance",
    "miliexp_ratio", "delta_miliexp_ratio", "milperson_ratio",
    "delta_milperson_ratio",
    "sfi_high", "sfi_low",
    "num_baad_high", "num_baad_low",
    "high_GDPPC_log", "llow_GDPPC_log", "ccode1_degree", "ccode2_degree",
    "regime_dur_max", "regime_dur_min",
    "low_democracy", "high_democracy",
    "delta_low_democracy", "delta_high_democracy",
    "lnminidist"
  )

  data_full %>%
    dplyr::select(dplyr::all_of(model_columns)) %>%
    na.omit() %>%
    dplyr::ungroup()
})

## ROC analysis of the random-forest model ------------------------------------
# Written over a list of fitted models so more models can be appended to
# `ModelResults` (with a matching entry in `modelname`).
ModelResults <- list(RF_base)
modelname <- "Full data"

# Predicted scores and observed outcomes for each model.
# NOTE(review): RF_base is presumably a randomForest regression fit; if so,
# predict() without `newdata` returns out-of-bag predictions and `$y` is the
# training response. Confirm against the fitting script.
Y_pred <- lapply(ModelResults, function(fit) predict(fit, type = "response"))
Y_obs <- lapply(ModelResults, function(fit) fit$y)

# One ROC curve per model. pROC:: is spelled out because the result is stored
# in a variable that is itself called `roc`.
# levels/direction are fixed rather than auto-detected from the data:
# 0 = control, 1 = case, and a higher predicted score means "case". pROC
# documents that the automatic direction can bias the AUC upward.
# Assumes the outcome is coded 0/1, as in the Table 1 output below.
roc <- Map(
  function(obs, pred) {
    pROC::roc(obs, pred, levels = c(0, 1), direction = "<")
  },
  Y_obs, Y_pred
)

# Optimal threshold(s) under pROC's default `best.method`. Several rows can be
# returned when thresholds tie.
best_threshold_RF <- lapply(roc, function(r) {
  pROC::coords(r, x = "best", input = "threshold")
})
names(roc) <- modelname

# Long data frame of ROC coordinates (plotx = 1 - specificity,
# ploty = sensitivity), labelled with each model's AUC, for plotting.
roc_df <- lapply(roc, function(r) {
  data.frame(
    plotx = 1 - r$specificities,
    ploty = r$sensitivities,
    name = paste("AUC =", sprintf("%.3f", r$auc))
  )
}) %>%
  dplyr::bind_rows(.id = "modelname")

# Show the best threshold (output recorded from a previous run):
best_threshold_RF
# [[1]]
# threshold specificity sensitivity
# 1 0.07224137   0.9999587           1

#### Results for Table 1 -------------------------------------------------------
# Cross-tabulate the observed outcome against the model's predictions
# dichotomised at the best ROC threshold.
y_pred_rf <- predict(RF_base, type = "response")

# Predictions are matched to the rows of `data` by position only, so the counts
# must agree. On a data.frame, `$<-` would otherwise silently recycle a shorter
# vector whose length divides nrow(data).
stopifnot(length(y_pred_rf) == nrow(data))
data$y_pred_RF <- y_pred_rf

# coords("best") can return several rows when thresholds tie. Extract the
# threshold column by name and keep the first value; coercing the first column
# with as.numeric() errors in the tie case.
best_threshold_rf <- best_threshold_RF[[1]][["threshold"]][1]

table(data$trigercrisis, data$y_pred_RF >= best_threshold_rf)
#FALSE    TRUE
# 0 1090305      45
# 1       0      42
#################################################################
# Clear the workspace before the next section of the analysis.
rm(list = ls())
